from functools import partial

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator

from .template import HardTemplate


class DataManager:
    """Builds train/dev DataLoaders for prompt-based (hard-template) training.

    Wraps a tokenizer together with a :class:`HardTemplate` and converts raw
    ``label<TAB>content`` text lines into model-ready, fixed-length inputs.
    """

    def __init__(self, config, tokenizer):
        # `config` is expected to expose: prompt, train_path, valid_path,
        # max_seq_len, max_label_len, batch_size.
        self.config = config
        self.tokenizer = tokenizer
        self.hard_template = HardTemplate(prompt=self.config.prompt)

    def get_data_loder(self):
        """Load the raw text datasets and return ``(train, dev)`` DataLoaders.

        NOTE(review): the method name keeps its historical typo for backward
        compatibility; prefer :meth:`get_data_loader` in new code.
        """
        # The 'text' loader yields one example per line under the key "text".
        dataset = load_dataset('text', data_files={'train': self.config.train_path,
                                                   'dev': self.config.valid_path})
        convert_func = partial(self.convert_example,
                               tokenizer=self.tokenizer,
                               hard_template=self.hard_template,
                               max_seq_len=self.config.max_seq_len,
                               max_label_len=self.config.max_label_len)

        # batched=True: convert_example receives a dict of lists, not a single row.
        dataset = dataset.map(convert_func, batched=True)

        train_dataloader = DataLoader(dataset["train"],
                                      shuffle=True,
                                      collate_fn=default_data_collator,
                                      batch_size=self.config.batch_size)
        dev_dataloader = DataLoader(dataset["dev"],
                                    collate_fn=default_data_collator,
                                    batch_size=self.config.batch_size)
        return train_dataloader, dev_dataloader

    def get_data_loader(self):
        """Correctly-spelled alias for :meth:`get_data_loder`."""
        return self.get_data_loder()

    @staticmethod
    def convert_example(examples: dict,
                        tokenizer,
                        max_seq_len: int,
                        max_label_len: int,
                        hard_template: "HardTemplate",
                        train_mode=True,
                        return_tensor=False) -> dict:
        """Convert raw text samples into model input features.

        Args:
            examples (dict): raw samples, e.g. -> {
                                                    "text": [
                                                        'label<TAB>content',
                                                        ...
                                                    ]
                                                  }
            tokenizer: a HuggingFace-style tokenizer (callable, has
                ``pad_token_id``).
            max_seq_len (int): maximum sequence length; shorter sequences are
                padded up to this length by the template.
            max_label_len (int): maximum label length; shorter labels are
                padded with ``pad_token_id``.
            hard_template (HardTemplate): the prompt template renderer.
            train_mode (bool): True for training (lines carry a label prefix),
                False for inference (lines are content only).
            return_tensor (bool): return ``torch.LongTensor`` values if True,
                otherwise ``np.array`` values.

        Returns:
            dict (str: np.array) -> tokenized_output = {
                                'input_ids': [[1, 47, 10, 7, 304, 3, 3, 3, 3, 47, 27, 247, 98, 105, 512, 777, 15, 12043, 2], ...],
                                'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], ...],
                                'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], ...],
                                'mask_positions': [[5, 6, 7, 8], ...],
                                'mask_labels': [[2372, 3442, 0, 0], [2643, 4434, 2334, 0], ...]
                            }
        """
        tokenized_output = {
            'input_ids': [],
            'token_type_ids': [],
            'attention_mask': [],
            'mask_positions': [],
            'mask_labels': []
        }

        for example in examples['text']:
            if train_mode:
                line = example.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline in the file).
                    continue
                # maxsplit=1 keeps any tabs inside the content intact.
                label, content = line.split('\t', 1)
            else:
                content = example.strip()

            inputs_dict = {
                'textA': content,
                'MASK': '[MASK]'
            }

            encoded_inputs = hard_template(
                inputs_dict=inputs_dict,
                tokenizer=tokenizer,
                max_seq_len=max_seq_len,
                mask_length=max_label_len)
            tokenized_output['input_ids'].append(encoded_inputs["input_ids"])
            tokenized_output['token_type_ids'].append(encoded_inputs["token_type_ids"])
            tokenized_output['attention_mask'].append(encoded_inputs["attention_mask"])
            tokenized_output['mask_positions'].append(encoded_inputs["mask_position"])

            if train_mode:
                # Tokenize the label, then drop the special tokens the
                # tokenizer adds at both ends ([CLS]/[SEP] for BERT-style).
                label_ids = tokenizer(text=[label])['input_ids'][0][1:-1]
                # Truncate, then right-pad the label to max_label_len.
                label_ids = label_ids[:max_label_len]
                label_ids = label_ids + [tokenizer.pad_token_id] * (max_label_len - len(label_ids))
                tokenized_output['mask_labels'].append(label_ids)

        for k, v in tokenized_output.items():
            tokenized_output[k] = torch.LongTensor(v) if return_tensor else np.array(v)

        return tokenized_output
